{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Z8YbxIz9FpSE"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape\n",
        "from tensorflow.keras.models import Model\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import os\n",
        "import urllib.request\n",
        "import tarfile\n",
        "\n",
        "# Download STL-10 dataset\n",
        "url = 'http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz'\n",
        "file_name = 'stl10_binary.tar.gz'\n",
        "\n",
        "if not os.path.exists(file_name):\n",
        "    urllib.request.urlretrieve(url, file_name)\n",
        "\n",
        "# Extract the dataset\n",
        "tar = tarfile.open(file_name, \"r:gz\")\n",
        "tar.extractall()\n",
        "tar.close()\n",
        "\n",
        "# Load the dataset\n",
        "data_dir = 'stl10_binary'\n",
        "file_names = ['train_X.bin', 'train_y.bin', 'test_X.bin', 'test_y.bin']\n",
        "\n",
        "x_train_path = os.path.join(data_dir, file_names[0])\n",
        "y_train_path = os.path.join(data_dir, file_names[1])\n",
        "x_test_path = os.path.join(data_dir, file_names[2])\n",
        "y_test_path = os.path.join(data_dir, file_names[3])\n",
        "\n",
        "x_train = np.fromfile(x_train_path, dtype=np.uint8).reshape(-1, 3, 96, 96).transpose(0, 2, 3, 1)\n",
        "y_train = np.fromfile(y_train_path, dtype=np.uint8) - 1  # Class labels range from 1 to 10, so subtract 1\n",
        "x_test = np.fromfile(x_test_path, dtype=np.uint8).reshape(-1, 3, 96, 96).transpose(0, 2, 3, 1)\n",
        "y_test = np.fromfile(y_test_path, dtype=np.uint8) - 1  # Class labels range from 1 to 10, so subtract 1\n",
        "\n",
        "\n",
        "# Preprocessing: normalize the data\n",
        "x_train = x_train.astype('float32') / 255.0\n",
        "x_test = x_test.astype('float32') / 255.0\n",
        "\n",
        "# Define the VAE architecture\n",
        "latent_dim = 128\n",
        "\n",
        "# Encoder\n",
        "inputs = Input(shape=(96, 96, 3))\n",
        "x = Conv2D(32, (3, 3), activation='relu', strides=(2, 2), padding='same')(inputs)\n",
        "x = Conv2D(64, (3, 3), activation='relu', strides=(2, 2), padding='same')(x)\n",
        "x = Flatten()(x)\n",
        "x = Dense(256, activation='relu')(x)\n",
        "\n",
        "# Latent space\n",
        "z_mean = Dense(latent_dim)(x)\n",
        "z_log_var = Dense(latent_dim)(x)\n",
        "\n",
        "# Reparameterization trick\n",
        "\n",
        "def sampling(args):\n",
        "    z_mean, z_log_var = args\n",
        "    epsilon = tf.random.normal(shape=(tf.shape(z_mean)[0], latent_dim))\n",
        "    return z_mean + tf.exp(0.5 * z_log_var) * epsilon\n",
        "\n",
        "z = tf.keras.layers.Lambda(sampling)([z_mean, z_log_var])\n",
        "\n",
        "# Decoder\n",
        "decoder_inputs = Input(shape=(latent_dim,))\n",
        "x = Dense(6 * 6 * 64, activation='relu')(decoder_inputs)\n",
        "x = Reshape((6, 6, 64))(x)\n",
        "x = Conv2DTranspose(64, (3, 3), activation='relu', strides=(2, 2), padding='same')(x)\n",
        "x = Conv2DTranspose(32, (3, 3), activation='relu', strides=(2, 2), padding='same')(x)\n",
        "outputs = Conv2DTranspose(3, (3, 3), activation='sigmoid', padding='same')(x)\n",
        "\n",
        "# VAE model\n",
        "encoder = Model(inputs, z_mean)\n",
        "decoder = Model(decoder_inputs, outputs)\n",
        "vae_output = decoder(z)\n",
        "vae = Model(inputs, vae_output)\n",
        "\n",
        "# Reshape inputs and outputs\n",
        "\n",
        "inputs_reshaped = tf.image.resize(inputs, (24, 24))\n",
        "vae_output_reshaped = tf.image.resize(vae_output, (24, 24))\n",
        "\n",
        "# Define the loss function\n",
        "reconstruction_loss = tf.keras.losses.binary_crossentropy(tf.reshape(inputs_reshaped, (-1, 24 * 24 * 3)),\n",
        "                                                          tf.reshape(vae_output_reshaped, (-1, 24 * 24 * 3)))\n",
        "\n",
        "reconstruction_loss *= 24 * 24 * 3\n",
        "kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)\n",
        "kl_loss = tf.reduce_mean(kl_loss, axis=-1)\n",
        "kl_loss *= -0.5\n",
        "\n",
        "vae_loss = tf.reduce_mean(reconstruction_loss + kl_loss)\n",
        "\n",
        "# Compile the model\n",
        "vae.add_loss(vae_loss)\n",
        "vae.compile(optimizer='adam')\n",
        "\n",
        "# Train the model\n",
        "epochs = 100\n",
        "batch_size = 128\n",
        "history = vae.fit(x_train, epochs=epochs, batch_size=batch_size, validation_data=(x_test, None))\n",
        "\n",
        "# Store the training and validation loss (reconstruction loss) in the history object\n",
        "history.history['train_reconstruction_loss'] = history.history['loss']\n",
        "history.history['val_reconstruction_loss'] = history.history['val_loss']\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import torchvision\n",
        "import torchvision.transforms as transforms\n",
        "from torch.utils.data import DataLoader, random_split\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Step 1: Data Preprocessing\n",
        "transform = transforms.Compose([\n",
        "  transforms.ToTensor(),\n",
        "  transforms.Resize((32, 32)),  # Resize the images to 32x32\n",
        "  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # Normalize the data\n",
        "])\n",
        "\n",
        "stl10_dataset = torchvision.datasets.STL10(root='./data', split='train', download=True, transform=transform)\n",
        "\n",
        "# Divide data into training and validation sets\n",
        "train_size = int(0.8 * len(stl10_dataset))\n",
        "val_size = len(stl10_dataset) - train_size\n",
        "train_dataset, val_dataset = random_split(stl10_dataset, [train_size, val_size])\n",
        "\n",
        "trainloader = DataLoader(train_dataset, batch_size=64, shuffle=True)\n",
        "valloader = DataLoader(val_dataset, batch_size=64, shuffle=False)\n",
        "\n",
        "# Step 2: Model Architecture\n",
        "class VAE(nn.Module):\n",
        "    def __init__(self, latent_dim=64):\n",
        "        super(VAE, self).__init__()\n",
        "self.latent_dim = latent_dim\n",
        "\n",
        "        # Encoder layers\n",
        "self.encoder = nn.Sequential(\n",
        "            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1), nn.ReLU(),\n",
        "            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1), nn.ReLU(),\n",
        "            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), nn.ReLU(),\n",
        "            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), nn.ReLU(),\n",
        "            nn.Flatten(),\n",
        "            nn.Linear(256 * 2 * 2, 1024),\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(1024, self.latent_dim * 2),\n",
        "        )\n",
        "\n",
        "        # Decoder layers\n",
        "self.decoder = nn.Sequential(\n",
        "nn.Linear(self.latent_dim, 1024),\n",
        "nn.ReLU(),\n",
        "nn.Linear(1024, 256 * 2 * 2),\n",
        "nn.ReLU(),\n",
        "nn.Unflatten(1, (256, 2, 2)),\n",
        "            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),\n",
        "nn.ReLU(),\n",
        "            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),\n",
        "nn.ReLU(),\n",
        "            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),\n",
        "nn.ReLU(),\n",
        "            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),\n",
        "nn.Tanh(),  # To map output to [-1, 1] range for images with normalized data\n",
        "        )\n",
        "\n",
        "    def encode(self, x):\n",
        "        x = self.encoder(x)\n",
        "        mu = x[:, :self.latent_dim]\n",
        "logvar = x[:, self.latent_dim:]\n",
        "        return mu, logvar\n",
        "\n",
        "    def reparameterize(self, mu, logvar):\n",
        "        std = torch.exp(0.5 * logvar)\n",
        "        eps = torch.randn_like(std)\n",
        "        z = mu + eps * std\n",
        "        return z\n",
        "\n",
        "    def decode(self, z):\n",
        "        return self.decoder(z)\n",
        "\n",
        "    def forward(self, x):\n",
        "        mu, logvar = self.encode(x)\n",
        "        z = self.reparameterize(mu, logvar)\n",
        "        return self.decode(z), mu, logvar\n",
        "\n",
        "# Step 3: Loss Function with Regularizer (KL Divergence)\n",
        "def vae_loss(recon_x, x, mu, logvar):\n",
        "    # Reconstruction Loss (MSE for images)\n",
        "reconstruction_loss = nn.MSELoss()(recon_x, x)\n",
        "    # KL Divergence Loss (Regularizer)\n",
        "kl_divergence_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())\n",
        "    return reconstruction_loss + kl_divergence_loss\n",
        "\n",
        "# Step 4: Training with Weight Decay and Learning Rate Scheduler\n",
        "def train_vae(model, trainloader, valloader, optimizer, num_epochs=10):\n",
        "model.train()\n",
        "    losses = []\n",
        "val_losses = []\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "running_loss = 0.0\n",
        "        for i, data in enumerate(trainloader, 0):\n",
        "            inputs, _ = data\n",
        "            inputs = inputs.to(device)\n",
        "\n",
        "optimizer.zero_grad()\n",
        "\n",
        "recon_batch, mu, logvar = model(inputs)\n",
        "            loss = vae_loss(recon_batch, inputs, mu, logvar)\n",
        "\n",
        "loss.backward()\n",
        "optimizer.step()\n",
        "\n",
        "running_loss += loss.item()\n",
        "\n",
        "epoch_loss = running_loss / len(trainloader)\n",
        "losses.append(epoch_loss)\n",
        "\n",
        "        # Validation loss\n",
        "model.eval()\n",
        "        with torch.no_grad():\n",
        "val_loss = 0.0\n",
        "            for data in valloader:\n",
        "                inputs, _ = data\n",
        "                inputs = inputs.to(device)\n",
        "recon_batch, mu, logvar = model(inputs)\n",
        "val_loss += vae_loss(recon_batch, inputs, mu, logvar).item()\n",
        "\n",
        "val_loss /= len(valloader)\n",
        "val_losses.append(val_loss)\n",
        "\n",
        "model.train()\n",
        "\n",
        "        print(f\"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss}, Val Loss: {val_loss}\")\n",
        "\n",
        "    return losses, val_losses\n",
        "\n",
        "# Step 5: Sampling and Generation\n",
        "def generate_images(model, num_images=10):\n",
        "model.eval()\n",
        "    with torch.no_grad():\n",
        "        z = torch.randn(num_images, model.latent_dim).to(device)\n",
        "generated_images = model.decode(z).cpu()\n",
        "        return generated_images\n",
        "\n",
        "def plot_generated_vs_original(generated_images, original_images):\n",
        "    fig, axes = plt.subplots(2, len(generated_images), figsize=(15, 5))\n",
        "    for i, img in enumerate(generated_images):\n",
        "img_gen = img.permute(1, 2, 0)  # Transpose to (H, W, C)\n",
        "img_gen = (img_gen + 1) / 2.0  # De-normalize from [-1, 1] to [0, 1]\n",
        "        axes[0, i].imshow(img_gen)\n",
        "        axes[0, i].axis('off')\n",
        "\n",
        "    for i, img in enumerate(original_images[:len(generated_images)]):\n",
        "img_orig = img.permute(1, 2, 0)  # Transpose to (H, W, C)\n",
        "img_orig = (img_orig + 1) / 2.0  # De-normalize from [-1, 1] to [0, 1]\n",
        "        axes[1, i].imshow(img_orig)\n",
        "        axes[1, i].axis('off')\n",
        "\n",
        "    axes[0, 0].set_title('Generated Images')\n",
        "    axes[1, 0].set_title('Original Images')\n",
        "plt.show()\n",
        "\n",
        "# Main\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "vae_model = VAE(latent_dim=64).to(device)\n",
        "optimizer = optim.Adam(vae_model.parameters(), lr=0.001, weight_decay=1e-5)\n",
        "scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)  # Learning rate scheduler\n",
        "\n",
        "# Train VAE with losses recorded\n",
        "train_losses, val_losses = train_vae(vae_model, trainloader, valloader, optimizer, num_epochs=10)\n",
        "\n",
        "# Plot the learning curve\n",
        "def plot_learning_curve(train_losses, val_losses):\n",
        "plt.figure()\n",
        "    epochs = range(1, len(train_losses) + 1)\n",
        "plt.plot(epochs, train_losses, '-o', label='Train Loss')\n",
        "plt.plot(epochs, val_losses, '-o', label='Val Loss')\n",
        "plt.xlabel('Epoch')\n",
        "plt.ylabel('Loss')\n",
        "plt.legend()\n",
        "plt.title('Variational Autoencoder Learning Curve')\n",
        "plt.show()\n",
        "\n",
        "plot_learning_curve(train_losses, val_losses)\n",
        "\n",
        "# Generate and display images\n",
        "generated_images = generate_images(vae_model, num_images=10)\n",
        "original_images = []\n",
        "for i, data in enumerate(valloader, 0):\n",
        "    inputs, _ = data\n",
        "original_images.append(inputs[0])\n",
        "    if i>= 9:  # Display 10 original images\n",
        "        break\n",
        "\n",
        "plot_generated_vs_original(generated_images, original_images)\n"
      ],
      "metadata": {
        "id": "NOAFFc1EHfj9"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "from vit_keras import vit, utils\n",
        "\n",
        "# Load the ViT model\n",
        "image_size = 384\n",
        "classes = utils.get_imagenet_classes()\n",
        "model = vit.vit_b16(image_size=image_size, pretrained=True, include_top=True, pretrained_top=True)\n",
        "\n",
        "# Load an image\n",
        "url = 'https://upload.wikimedia.org/wikipedia/commons/d/d7/Granny_smith_and_cross_section.jpg'\n",
        "image = utils.read(url, image_size)\n",
        "\n",
        "# Preprocess the image\n",
        "X = vit.preprocess_inputs(image).reshape(1, image_size, image_size, 3)\n",
        "\n",
        "# Make a prediction\n",
        "y = model.predict(X)\n",
        "\n",
        "# Print the predicted class\n",
        "print(classes[y[0].argmax()])\n"
      ],
      "metadata": {
        "id": "x590_rH7Ht4r"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install tensorflow-addons\n",
        "!pip install vit_keras\n",
        "\n",
        "# Load the dataset\n",
        "x_train_path = os.path.join(data_dir, file_names[0])\n",
        "y_train_path = os.path.join(data_dir, file_names[1])\n",
        "x_test_path = os.path.join(data_dir, file_names[2])\n",
        "y_test_path = os.path.join(data_dir, file_names[3])\n",
        "\n",
        "x_train = np.fromfile(x_train_path, dtype=np.uint8).reshape(-1, 3, 96, 96).transpose(0, 2, 3, 1)\n",
        "y_train = np.fromfile(y_train_path, dtype=np.uint8) - 1\n",
        "x_test = np.fromfile(x_test_path, dtype=np.uint8).reshape(-1, 3, 96, 96).transpose(0, 2, 3, 1)\n",
        "y_test = np.fromfile(y_test_path, dtype=np.uint8) - 1\n",
        "\n",
        "# Preprocessing: normalize the data\n",
        "x_train = x_train.astype('float32') / 255.0\n",
        "x_test = x_test.astype('float32') / 255.0\n",
        "\n",
        "# Define the VIT Autoencoder architecture\n",
        "latent_dim = 128\n",
        "\n",
        "# Encoder (Vision Transformer)\n",
        "inputs = Input(shape=(96, 96, 3))\n",
        "x = vit.vit_l32(image_size=96, activation='gelu', pretrained=False, include_top=False, pretrained_top=False)(inputs)\n",
        "x = Reshape((-1, x.shape[-1]))(x)  # Flatten the sequence of patches\n",
        "x = tf.keras.layers.GlobalAveragePooling1D()(x)  # Reduce sequence to a single vector\n",
        "latent_space = Dense(latent_dim, activation='relu')(x)  # Dense layer for the latent representation\n",
        "\n",
        "# Decoder\n",
        "decoder_inputs = Input(shape=(latent_dim,))\n",
        "x = Dense(6 * 6 * 32, activation='relu')(decoder_inputs)\n",
        "x = Reshape((6, 6, 32))(x)\n",
        "x = Conv2DTranspose(32, (3, 3), activation='relu', strides=(2, 2), padding='same')(x)\n",
        "x = Conv2DTranspose(16, (3, 3), activation='relu', strides=(2, 2), padding='same')(x)\n",
        "outputs = Conv2DTranspose(3, (3, 3), activation='sigmoid', padding='same')(x)\n",
        "\n",
        "# VAE model\n",
        "encoder = Model(inputs, latent_space)\n",
        "decoder = Model(decoder_inputs, outputs)\n",
        "\n",
        "# Create the autoencoder by connecting the encoder and decoder\n",
        "autoencoder_output = decoder(encoder(inputs))\n",
        "autoencoder = Model(inputs, autoencoder_output)\n",
        "\n",
        "# Reshape inputs and outputs\n",
        "inputs_reshaped = tf.image.resize(inputs, (24, 24))\n",
        "autoencoder_output_reshaped = tf.image.resize(autoencoder_output, (24, 24))\n",
        "\n",
        "# Define the loss function (Autoencoder loss)\n",
        "reconstruction_loss = tf.keras.losses.mean_squared_error(tf.reshape(inputs_reshaped, (-1, 24 * 24 * 3)),\n",
        "                                                         tf.reshape(autoencoder_output_reshaped, (-1, 24 * 24 * 3)))\n",
        "\n",
        "autoencoder_loss = tf.reduce_mean(reconstruction_loss)\n",
        "\n",
        "# Compile the model\n",
        "autoencoder.add_loss(autoencoder_loss)\n",
        "autoencoder.compile(optimizer='adam')\n",
        "\n",
        "# Train the model\n",
        "epochs = 100\n",
        "batch_size = 128\n",
        "history = autoencoder.fit(x_train, epochs=epochs, batch_size=batch_size, validation_data=(x_test, None))\n",
        "\n",
        "# Plot the learning curves for loss\n",
        "plt.figure(figsize=(10, 6))\n",
        "plt.plot(history.history['loss'], label='Train Loss')\n",
        "plt.plot(history.history['val_loss'], label='Validation Loss')\n",
        "plt.xlabel('Epochs')\n",
        "plt.ylabel('Loss')\n",
        "plt.title('Training and Validation Loss')\n",
        "plt.legend()\n",
        "plt.grid()\n",
        "plt.show()\n",
        "\n",
        "# Generate and plot some reconstructed images using the VIT Autoencoder\n",
        "num_samples = 5\n",
        "random_indices = np.random.randint(0, len(x_test), num_samples)\n",
        "sample_images = x_test[random_indices]\n",
        "reconstructed_images = autoencoder.predict(sample_images)\n",
        "\n",
        "plt.figure(figsize=(10, 4))\n",
        "\n",
        "for i in range(num_samples):\n",
        "    plt.subplot(2, num_samples, i + 1)\n",
        "    plt.imshow(sample_images[i])\n",
        "    plt.title(\"Original\")\n",
        "    plt.axis('off')\n",
        "\n",
        "    plt.subplot(2, num_samples, num_samples + i + 1)\n",
        "    plt.imshow(reconstructed_images[i])\n",
        "    plt.title(\"Reconstructed\")\n",
        "    plt.axis('off')\n",
        "\n",
        "plt.show()\n"
      ],
      "metadata": {
        "id": "3mIQJv2dHycx"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Variational Autoencoder (VAE) using a Vision Transformer (ViT) as the encoder\n",
        "# Variational Autoencoder (VAE) using a Vision Transformer (ViT) as the encoder\n",
        "import tensorflow as tf\n",
        "from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape, LayerNormalization\n",
        "from tensorflow.keras.models import Model\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import os\n",
        "import urllib.request\n",
        "import tarfile\n",
        "import shutil\n",
        "from vit_keras import vit\n",
        "\n",
        "# Download STL-10 dataset (if not already downloaded)\n",
        "url = 'http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz'\n",
        "file_name = 'stl10_binary.tar.gz'\n",
        "\n",
        "if not os.path.exists(file_name):\n",
        "    urllib.request.urlretrieve(url, file_name)\n",
        "\n",
        "# Extract the dataset (if not already extracted)\n",
        "data_dir = 'stl10_binary'\n",
        "file_names = ['train_X.bin', 'train_y.bin', 'test_X.bin', 'test_y.bin']\n",
        "\n",
        "if os.path.exists(data_dir):\n",
        "    shutil.rmtree(data_dir)  # Delete the existing folder if it exists\n",
        "\n",
        "tar = tarfile.open(file_name, \"r:gz\")\n",
        "tar.extractall()\n",
        "tar.close()\n",
        "\n",
        "# Load the dataset\n",
        "x_train_path = os.path.join(data_dir, file_names[0])\n",
        "y_train_path = os.path.join(data_dir, file_names[1])\n",
        "x_test_path = os.path.join(data_dir, file_names[2])\n",
        "y_test_path = os.path.join(data_dir, file_names[3])\n",
        "\n",
        "x_train = np.fromfile(x_train_path, dtype=np.uint8).reshape(-1, 3, 96, 96).transpose(0, 2, 3, 1)\n",
        "y_train = np.fromfile(y_train_path, dtype=np.uint8) - 1\n",
        "x_test = np.fromfile(x_test_path, dtype=np.uint8).reshape(-1, 3, 96, 96).transpose(0, 2, 3, 1)\n",
        "y_test = np.fromfile(y_test_path, dtype=np.uint8) - 1\n",
        "\n",
        "# Preprocessing: normalize the data\n",
        "x_train = x_train.astype('float32') / 255.0\n",
        "x_test = x_test.astype('float32') / 255.0\n",
        "\n",
        "# Define the VAE architecture with Vision Transformer\n",
        "latent_dim = 128\n",
        "\n",
        "# Load the pre-trained Vision Transformer model weights\n",
        "vit_weights_path = 'ViT-L_32_imagenet21k+imagenet2012.npz'\n",
        "\n",
        "# Encoder\n",
        "inputs = Input(shape=(96, 96, 3))\n",
        "x = vit.vit_l32(image_size=96, activation='relu', pretrained=False, include_top=False, pretrained_top=False)(inputs)\n",
        "x = LayerNormalization(epsilon=1e-6)(x)\n",
        "x = Flatten()(x)\n",
        "x = Dense(256, activation='relu')(x)\n",
        "# Latent space\n",
        "z_mean = Dense(latent_dim)(x)\n",
        "z_log_var = Dense(latent_dim)(x)\n",
        "\n",
        "# Reparameterization trick\n",
        "def sampling(args):\n",
        "    z_mean, z_log_var = args\n",
        "    epsilon = tf.random.normal(shape=(tf.shape(z_mean)[0], latent_dim))\n",
        "    return z_mean + tf.exp(0.5 * z_log_var) * epsilon\n",
        "\n",
        "z = tf.keras.layers.Lambda(sampling)([z_mean, z_log_var])\n",
        "\n",
        "# Decoder\n",
        "decoder_inputs = Input(shape=(latent_dim,))\n",
        "x = Dense(6 * 6 * 32, activation='relu')(decoder_inputs)\n",
        "x = Reshape((6, 6, 32))(x)\n",
        "x = Conv2DTranspose(32, (3, 3), activation='relu', strides=(2, 2), padding='same')(x)\n",
        "x = Conv2DTranspose(16, (3, 3), activation='relu', strides=(2, 2), padding='same')(x)\n",
        "outputs = Conv2DTranspose(3, (3, 3), activation='sigmoid', padding='same')(x)\n",
        "\n",
        "# VAE model\n",
        "encoder = Model(inputs, z_mean)\n",
        "decoder = Model(decoder_inputs, outputs)\n",
        "vae_output = decoder(z)\n",
        "vae = Model(inputs, vae_output)\n",
        "\n",
        "# Reshape inputs and outputs\n",
        "inputs_reshaped = tf.image.resize(inputs, (24, 24))\n",
        "vae_output_reshaped = tf.image.resize(vae_output, (24, 24))\n",
        "\n",
        "# Define the loss function\n",
        "reconstruction_loss = tf.keras.losses.binary_crossentropy(tf.reshape(inputs_reshaped, (-1, 24 * 24 * 3)),\n",
        "                                                          tf.reshape(vae_output_reshaped, (-1, 24 * 24 * 3)))\n",
        "\n",
        "reconstruction_loss *= 24 * 24 * 3\n",
        "kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)\n",
        "kl_loss = tf.reduce_mean(kl_loss, axis=-1)\n",
        "kl_loss *= -0.5\n",
        "\n",
        "vae_loss = tf.reduce_mean(reconstruction_loss + kl_loss)\n",
        "\n",
        "# Compile the model\n",
        "vae.add_loss(vae_loss)\n",
        "vae.compile(optimizer='adam')\n",
        "\n",
        "# Train the model\n",
        "epochs = 100\n",
        "batch_size = 128\n",
        "history = vae.fit(x_train, epochs=epochs, batch_size=batch_size, validation_data=(x_test, None))\n",
        "\n",
        "# Store the training and validation loss (reconstruction loss) in the history object\n",
        "history.history['train_reconstruction_loss'] = history.history['loss']\n",
        "history.history['val_reconstruction_loss'] = history.history['val_loss']\n",
        "\n",
        "# Plot Training and Validation Loss\n",
        "train_loss = history.history['loss']\n",
        "val_loss = history.history['val_loss']\n",
        "\n",
        "plt.figure(figsize=(10, 6))\n",
        "plt.plot(range(1, epochs+1), train_loss, label='Train Loss')\n",
        "plt.plot(range(1, epochs+1), val_loss, label='Validation Loss')\n",
        "plt.xlabel('Epochs')\n",
        "plt.ylabel('Loss')\n",
        "plt.title('Training and Validation Loss')\n",
        "plt.legend()\n",
        "plt.grid(True)\n",
        "plt.show()\n"
      ],
      "metadata": {
        "id": "d9RkMYojIGm9"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}